package cn.az13js.satool;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.FileReader;
import java.util.LinkedList;

/**
 * Static helper methods: multi-dimensional array allocation, command-line argument lookup,
 * a reversible byte-level file transform, a CSV number reader, and a few numeric routines
 * (normal sampling, AR(p) fitting, Riemann-sum integration).
 */
public class Util {

    /** Number of chars requested from the reader per call in {@link #readData(String)}. */
    private static final int CACHE_SIZE = 32768;

    /** Size of the byte buffer used by {@link #fileTransforCode(String, String)}. */
    private static final int TRANSFORM_BUFFER_SIZE = 1048576;

    /** Allocates a zero-filled {@code double} array of length {@code a}. */
    public static double[] nDArray(int a) {
        return new double[a];
    }

    /** Allocates a zero-filled {@code a x b} {@code double} array. */
    public static double[][] nDArray(int a, int b) {
        return new double[a][b];
    }

    /** Allocates a zero-filled {@code a x b x c} {@code double} array. */
    public static double[][][] nDArray(int a, int b, int c) {
        return new double[a][b][c];
    }

    /** Allocates a zero-filled {@code a x b x c x d} {@code double} array. */
    public static double[][][][] nDArray(int a, int b, int c, int d) {
        return new double[a][b][c][d];
    }

    /**
     * Returns the value that follows a command-line flag.
     *
     * @param args  the command-line arguments to search
     * @param match the flag to look for
     * @return the argument immediately after the first occurrence of {@code match}, or
     *         {@code ""} if {@code match} is absent or is the last argument
     */
    public static String commandParam(String[] args, String match) {
        for (int i = 0; i + 1 < args.length; i++) {
            if (match.equals(args[i])) {
                return args[i + 1];
            }
        }
        return "";
    }

    /**
     * Tells whether a flag is present among the command-line arguments.
     *
     * @param args  the command-line arguments to search
     * @param match the flag to look for
     * @return {@code true} if some element of {@code args} equals {@code match}
     */
    public static boolean commandParamExists(String[] args, String match) {
        for (String arg : args) {
            if (match.equals(arg)) {
                return true;
            }
        }
        return false;
    }

    /** Prints {@code msg} to standard output, followed by a line separator. */
    public static void message(String msg) {
        System.out.println(msg);
    }

    /**
     * Writes a transformed copy of {@code src} to {@code dst}, mapping every byte {@code b} to
     * {@code (byte) (127 - b)}.
     *
     * <p>With modulo-256 byte arithmetic the mapping is its own inverse, so running the output
     * through this method again restores the original bytes. {@code dst} is created or
     * overwritten. If it names the same file as {@code src}, the source data is lost before
     * it is read.
     *
     * @param src path of the file to read
     * @param dst path of the file to write
     * @return {@code true} on success; {@code false} if any I/O error occurred while opening,
     *         reading, writing or closing either file ({@code dst} may then be partially written)
     */
    public static boolean fileTransforCode(String src, String dst) {
        // try-with-resources closes both streams; a failure while closing reaches the catch
        // below, so it is still reported as false.
        try (FileInputStream srcFile = new FileInputStream(src);
             FileOutputStream dstFile = new FileOutputStream(dst)) {
            byte[] byteData = new byte[TRANSFORM_BUFFER_SIZE];
            int readDataLength;
            while (-1 != (readDataLength = srcFile.read(byteData))) {
                for (int i = 0; i < readDataLength; i++) {
                    byteData[i] = (byte) (127 - byteData[i]);
                }
                dstFile.write(byteData, 0, readDataLength);
            }
            return true;
        } catch (IOException e) {
            return false;
        }
    }

    /**
     * Draws one sample from a normal distribution using the Box-Muller transform.
     *
     * @param mean              mean of the distribution
     * @param standardDeviation standard deviation of the distribution
     * @return a pseudo-random sample from Normal(mean, standardDeviation^2)
     */
    public static double normalDistribution(double mean, double standardDeviation) {
        // Math.random() is in [0, 1), so 1 - Math.random() is in (0, 1] and log() never sees 0.
        double u1 = 1 - Math.random();
        double u2 = 1 - Math.random();
        double radius = Math.sqrt(-2 * Math.log(u2));
        double theta = 2 * Math.PI * u1;
        double z = radius * Math.cos(theta);
        return mean + (z * standardDeviation);
    }

    /**
     * Reads the first comma-separated field of every line of a text file as a {@code double}.
     *
     * <p>Everything after the first {@code ','} on a line is ignored. Each field is parsed with
     * {@link Double#parseDouble(String)}, which ignores leading and trailing whitespace, so
     * {@code "\r\n"} line endings work. Reading stops silently at the first field that does
     * not parse, and the values collected so far are returned. Such a field includes an empty
     * one, such as a blank line or the empty tail after a final newline.
     *
     * @param fileName path of the file to read (decoded with the platform default charset)
     * @return the parsed values, in file order
     * @throws IOException if the file cannot be opened, read or closed
     */
    public static LinkedList<Double> readData(String fileName) throws IOException {
        /*
         * Two-state scanner fed one char at a time.
         *   Reading the first field (initial state):
         *     '\n'  -> the buffered field is complete; stay in this state
         *     ','   -> the buffered field is complete; start skipping the rest of the line
         *     other -> append to the buffer
         *     EOF   -> the buffered field is complete
         *   Skipping the rest of the line:
         *     '\n'  -> go back to reading the first field
         *     EOF   -> nothing left to emit
         *     other (including ',') -> ignored
         */
        class StatsMachine {
            private boolean skippingRestOfLine = false;
            private final StringBuilder fileContents;

            StatsMachine(StringBuilder fileContents) {
                this.fileContents = fileContents;
            }

            /** Consumes one char; returns true when a complete field is waiting in the buffer. */
            boolean handle(char ch) {
                if (skippingRestOfLine) {
                    if ('\n' == ch) {
                        skippingRestOfLine = false;
                    }
                    return false;
                }
                if ('\n' == ch) {
                    return true;
                }
                if (',' == ch) {
                    skippingRestOfLine = true;
                    return true;
                }
                fileContents.append(ch);
                return false;
            }

            /** Signals end of input; returns true when a final field is waiting in the buffer. */
            boolean finish() {
                boolean pending = !skippingRestOfLine;
                skippingRestOfLine = false;
                return pending;
            }

            /**
             * Parses and clears the buffered field.
             *
             * @throws NumberFormatException if the field is not a number (the buffer is
             *         cleared anyway)
             */
            double getNumberValue() {
                String s = fileContents.toString();
                fileContents.setLength(0);
                return Double.parseDouble(s);
            }
        }

        char[] cache = new char[CACHE_SIZE];
        LinkedList<Double> dataList = new LinkedList<Double>();
        StatsMachine machine = new StatsMachine(new StringBuilder(CACHE_SIZE));
        try (FileReader csv = new FileReader(fileName)) {
            int readSize;
            // Only -1 marks end of file: a read may legally return fewer chars than requested.
            // Only the first readSize chars are valid; the rest of the cache holds stale data
            // from the previous read.
            while ((readSize = csv.read(cache, 0, CACHE_SIZE)) != -1) {
                for (int i = 0; i < readSize; i++) {
                    if (machine.handle(cache[i])) {
                        dataList.add(machine.getNumberValue());
                    }
                }
            }
            if (machine.finish()) {
                dataList.add(machine.getNumberValue());
            }
        } catch (NumberFormatException ignored) {
            // A non-numeric field (e.g. a blank line or trailing whitespace) marks the end of
            // the data; keep what has been parsed so far.
        }
        return dataList;
    }

    /**
     * Fits the coefficients of an autoregressive model of order {@code p}.
     *
     * <p>The model is {@code x[t] ~ a[0]*x[t-1] + ... + a[p-1]*x[t-p]}. It is fitted by
     * full-batch gradient descent on half the mean squared one-step prediction residual over
     * {@code t = p .. sequence.length - 1}.
     *
     * <p>The initial coefficients come from {@code Distribution.generateData(p, 0, 1.0)}.
     * Presumably these are {@code p} draws with mean 0 and standard deviation 1 — confirm
     * against {@code Distribution}. They are printed via {@link #message(String)} before
     * training.
     *
     * @param sequence the observed series
     * @param p        the model order; must be non-negative, and smaller than
     *                 {@code sequence.length} when positive
     * @param lr       the gradient-descent learning rate
     * @param epoch    the number of full-batch update steps
     * @return the fitted coefficients; {@code a[k-1]} multiplies {@code x[t-k]}
     * @throws IllegalArgumentException if {@code p < 0}, or if {@code p > 0} and the sequence
     *         has no more than {@code p} elements (there would be no residual to fit)
     */
    public static double[] ar(double[] sequence, int p, double lr, int epoch) {
        if (p < 0) {
            throw new IllegalArgumentException("AR order p must be non-negative, got " + p);
        }
        if (p > 0 && sequence.length <= p) {
            throw new IllegalArgumentException("sequence length " + sequence.length
                    + " must be greater than AR order p=" + p);
        }
        double[] a = Distribution.generateData(p, 0, 1.0);
        double[] delta = nDArray(p);
        int samples = sequence.length - p;
        double[] residual = nDArray(samples);
        for (int i = 0; i < a.length; i++) {
            message("a[" + i + "]=" + a[i]);
        }
        for (int ep = 0; ep < epoch; ep++) {
            // residual[t - p] = x[t] minus its prediction from the p preceding values.
            for (int t = p; t < sequence.length; t++) {
                // Start from zero for every target. Accumulating across targets would corrupt
                // every residual after the first.
                double prediction = 0;
                for (int k = 1; k <= p; k++) {
                    prediction += a[k - 1] * sequence[t - k];
                }
                residual[t - p] = sequence[t] - prediction;
            }
            // Gradient of the loss: delta[i - 1] = -(1 / samples) * sum_t residual[t - p] * x[t - i]
            for (int i = 1; i <= p; i++) {
                double correlation = 0;
                for (int t = p; t < sequence.length; t++) {
                    correlation += residual[t - p] * sequence[t - i];
                }
                delta[i - 1] = -correlation / samples;
            }
            for (int i = 0; i < p; i++) {
                a[i] -= lr * delta[i];
            }
        }
        return a;
    }

    /**
     * Approximates {@code P(bottom <= X <= top)} for a normal variable X.
     *
     * <p>X has mean {@code average} and standard deviation {@code standardDeviation}. The
     * normal density is integrated with a left Riemann sum of {@code sp} equal-width strips
     * over {@code [bottom, top]}.
     *
     * @param bottom            lower integration limit
     * @param top               upper integration limit
     * @param sp                number of strips; the sum is empty when {@code sp <= 0}
     * @param average           mean of the distribution
     * @param standardDeviation standard deviation of the distribution
     * @return the approximate probability mass between {@code bottom} and {@code top}
     */
    public static double gaussianIntegral(double bottom, double top, int sp, double average, double standardDeviation) {
        // Shift the limits so the density is centred at zero.
        double a = bottom - average;
        double b = top - average;
        double sum = 0;
        double x;
        double y;
        for (int i = 0; i < sp; i++) {
            x = a + (b - a) * i / sp;
            y = Math.exp(x * x / -2 / standardDeviation / standardDeviation);
            sum += y * (b - a) / sp;
        }
        return sum / Math.sqrt(2 * Math.PI) / standardDeviation;
    }

    /**
     * Approximates the integral of {@link #func(double)} over {@code [a, b]} with a left
     * Riemann sum of {@code sp} equal-width strips.
     *
     * @param a  lower integration limit
     * @param b  upper integration limit
     * @param sp number of strips; the sum is empty (result 0) when {@code sp <= 0}
     * @return the approximate integral
     */
    public static double integral(double a, double b, int sp) {
        double sum = 0;
        double x;
        double y;
        for (int i = 0; i < sp; i++) {
            x = a + (b - a) * i / sp;
            y = func(x);
            sum += y * (b - a) / sp;
        }
        return sum;
    }

    /** Integrand used by {@link #integral(double, double, int)}: {@code exp(-x^2)}. */
    private static double func(double x) {
        return Math.exp(-(x * x));
    }

}
